{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "19f749cb-5b9f-428c-90d6-200d6b5f487b",
   "metadata": {},
   "source": [
    "## 1. The binary cross-entropy loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30819d87-f75d-48b1-b497-97c90b01cb11",
   "metadata": {},
   "source": [
    "Code example companion notebook for the blog article   \n",
    "[Losses Learned -- Optimizing Negative Log-Likelihood and Cross-Entropy in PyTorch (Part 1)](https://sebastianraschka.com/blog/2022/losses-learned-part1.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76f09764-f357-4710-ba79-13ff25ec5022",
   "metadata": {},
   "source": [
    "### 1.1 Example data for binary cross-entropy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "544481bb-209d-4d71-8189-183f9a6d8bc0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.7503, 0.9002, 0.6225, 0.2497, 0.0998])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "y_targets = torch.tensor([1., 1., 0., 0., 0.])\n",
    "\n",
    "logits = torch.tensor([1.1, 2.2, 0.5, -1.1, -2.2])\n",
    "probas = torch.sigmoid(logits)\n",
    "print(probas)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70fdd8b1-777b-4fd8-9122-e1ede1591f9b",
   "metadata": {},
   "source": [
    "### 1.2 Implementing the binary cross-entropy loss from scratch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "9188b69f-92be-44e4-92e9-a4d4ad6ba856",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3518)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def binary_logistic_loss_v1(probas, y_targets):\n",
    "    res = 0.\n",
    "    for i in range(y_targets.shape[0]):\n",
    "        if y_targets[i] == 1.:\n",
    "            res += torch.log(probas[i])\n",
    "        elif y_targets[i] == 0.:\n",
    "            res += torch.log(1-probas[i])            \n",
    "        else:\n",
    "            raise ValueError(f'Value {y_targets[i]} not allowed')\n",
    "    res *= -1\n",
    "    res /= y_targets.shape[0]\n",
    "\n",
    "    return res\n",
    "\n",
    "\n",
    "binary_logistic_loss_v1(probas, y_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "34a96141-6650-43db-bb5e-2fa77d0666bb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3518)"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def binary_logistic_loss_v2(probas, y_targets):\n",
    "    first = -y_targets.matmul(torch.log(probas))\n",
    "    second = -(1 - y_targets).matmul(torch.log(1 - probas))\n",
    "    return (first + second) / y_targets.shape[0]\n",
    "\n",
    "binary_logistic_loss_v2(probas, y_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "f6b1fd3d-2ee5-4a9c-9cf1-2c0465ef9116",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "38.6 µs ± 292 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit binary_logistic_loss_v1(probas, y_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "7379900d-8f6e-47f0-a7f4-b4815de0423d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.6 µs ± 47.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%timeit binary_logistic_loss_v2(probas, y_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ce72efd-c1dc-4ac4-b43b-6d321dd5f3a2",
   "metadata": {},
   "source": [
    "### 1.3 Using the binary cross-entropy loss in PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "f8b21962-2b59-4935-a78d-51656552dab6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3518)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bce = torch.nn.BCELoss()\n",
    "bce(probas, y_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "684af43d-c6f0-40ca-850a-841046c053df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3518)"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MyBCELoss(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    " \n",
    "    def forward(self, inputs, targets):        \n",
    "        return binary_logistic_loss_v2(inputs, targets)\n",
    "    \n",
    "    \n",
    "my_bce = MyBCELoss()\n",
    "my_bce(probas, y_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "ef6e2dcc-4c23-4ea5-a6fd-075f10437354",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3518)"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bce_logits = torch.nn.BCEWithLogitsLoss()\n",
    "bce_logits(logits, y_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "086924e4-9d7f-48e2-b578-93bc09154372",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3518)"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bce_logits(logits, y_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "4d64ad3d-b5e4-4c97-9d91-c1376fd6f7cb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3518)"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bce(torch.sigmoid(logits), y_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dd00f5b-17c3-4e6e-937e-6f9530d9266d",
   "metadata": {},
   "source": [
    "**Log-Sum Trick and Logsigmoid**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "c7f1a8fd-c1c9-4327-b61c-ed8f1cd0ab98",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3518)"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def binary_logistic_loss_v2(probas, y_targets):\n",
    "    first = -y_targets.matmul(torch.log(probas))\n",
    "    second = -(1 - y_targets).matmul(torch.log(1 - probas))\n",
    "    return (first + second) / y_targets.shape[0]\n",
    "\n",
    "binary_logistic_loss_v2(probas, y_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "33eb3842-fd08-441d-89a2-e42d11fc9aa5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3518)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "def binary_logistic_loss_v3(logits, y_targets):\n",
    "    first = -y_targets.matmul(F.logsigmoid(logits))\n",
    "    second = -(1 - y_targets).matmul(F.logsigmoid(logits) - logits)\n",
    "    return (first + second) / y_targets.shape[0]\n",
    "\n",
    "binary_logistic_loss_v3(logits, y_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7831dd2-3d2c-4ce6-af1f-729383488bb4",
   "metadata": {},
   "source": [
    "### 1.4 PyTorch’s functional vs object-oriented API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "535174f9-121d-4fb0-95a8-b7e9b23a3495",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3518)"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "F.binary_cross_entropy(probas, y_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "b7a791d1-ae74-4c52-a5bc-7db553551a4a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.3518)"
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.binary_cross_entropy_with_logits(logits, y_targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96693c3d-5eea-4943-8d13-0832fd578b4d",
   "metadata": {},
   "source": [
    "### 1.5 A PyTorch loss function cheatsheet (so far)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b66bd628-04cb-4d13-8ecc-208a4ac04658",
   "metadata": {},
   "source": [
    "- Note that we use different inputs here, which is why the outputs differ from previous sections."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "f8f71c2e-182c-4d9e-984c-74128f75c34c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.4399)"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "logits = torch.tensor([ -1., 0,  1.])\n",
    "targets = torch.tensor([0.,  0., 1.])\n",
    "\n",
    "bce_logits = torch.nn.BCEWithLogitsLoss()\n",
    "bce_logits(logits, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "a4645186-6cda-44f1-9ca5-cef8f66c0896",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.4399)"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.binary_cross_entropy_with_logits(logits, targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "3f4c04f6-98cc-40f4-80a0-e52f7ae82c8e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.4399)"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bce = torch.nn.BCELoss()\n",
    "bce(torch.sigmoid(logits), targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "514dc179-c671-4a6f-9377-46c0af3794f9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.4399)"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "F.binary_cross_entropy(torch.sigmoid(logits), targets)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
